# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 12:12:47 2023

@author: Andrei Sontag

Code to perform the curvature analysis of the voting data for the straight-line
manifold projection. This corresponds to the figure in Section SI.6.1 of the 
Supplementary Information.
"""

# importing required libraries
import numpy as np
import pandas as pd
import os, fnmatch
import matplotlib.pyplot as plt

plt.close('all')

# No-op: './' is already the current working directory.
# NOTE(review): perhaps this was meant to switch to the script's own folder — confirm.
os.chdir(r'./')

# Find data files: every *.csv in the working directory. The list is sorted so
# that the pooling order (and hence the floating-point sums and the layout of
# dataA/dataB) does not depend on the arbitrary order returned by os.listdir.
fileOfDirectory = os.listdir('.')
pattern = r'*.csv'
files = sorted(fnmatch.filter(fileOfDirectory, pattern))
if not files:
    # Without data every later step fails with an obscure indexing/fit error.
    raise FileNotFoundError(
        "No files matching {0} found in {1}".format(pattern, os.getcwd()))

# Show the name of the files
print(files)

# Common grid of bin centres for the alignment (A-B)/(A+B): from -1 to 1 in steps
# of 1/30, i.e. 61 centres including both endpoints. Each group size yields its
# own set of attainable ratios, so snapping every round onto this grid lets all
# experiments be pooled.
bin_step = 1/30
ratios = np.arange(-1, 1+bin_step, bin_step)

# Per-bin accumulators: number of rounds, running sum of the abstention fraction
# and running sum of its square (converted to means/variances later).
count_r, mean_Abr, sqrmean_Abr = (np.zeros_like(ratios) for _ in range(3))

# Pooled per-round A and B vote fractions over all experiments
dataA, dataB = np.array([]), np.array([])

# game has 120 rounds + 1 to make the code more streamlined
nrounds = 121

# Column headings holding the group totals of A votes, B votes and abstentions
# for rounds 1..120. They are identical for every file, so build them once.
total_Avotes = [r"my_voting.{0:.0f}.group.A_votes".format(k) for k in range(1, nrounds)]
total_Bvotes = [r"my_voting.{0:.0f}.group.B_votes".format(k) for k in range(1, nrounds)]
total_Absvotes = [r"my_voting.{0:.0f}.group.abstain_votes".format(k) for k in range(1, nrounds)]

for file in files:
    # read the experiment's export into a data frame
    df = pd.read_csv(file)

    # Per-round group totals, read from the first row of the file
    tAvotes = df[total_Avotes].to_numpy()[0, :]
    tBvotes = df[total_Bvotes].to_numpy()[0, :]
    tAbsvotes = df[total_Absvotes].to_numpy()[0, :]

    # Group size, taken as the round-1 total of A votes, B votes and abstentions
    N = int(tAvotes[0] + tBvotes[0] + tAbsvotes[0])

    # Append this experiment's per-round vote fractions to the pooled arrays
    dataA = np.append(dataA, tAvotes/N)
    dataB = np.append(dataB, tBvotes/N)

    for k in range(nrounds-1):
        # difference and sum of A and B votes in the round
        diff = int(tAvotes[k]-tBvotes[k])
        sumv = int(tAvotes[k]+tBvotes[k])
        if sumv == 0:
            # Nobody voted A or B (everyone abstained): the alignment diff/sumv
            # is undefined (it would raise ZeroDivisionError), so the round
            # cannot be assigned to a bin and is skipped.
            continue

        # alignment (projection) of the round's point; in [-1, 1] for
        # non-negative vote counts
        round_ratio = diff/sumv
        # nearest bin centre on the `ratios` grid (argmin takes the first on ties)
        idx = np.argmin(abs(ratios-round_ratio))

        # accumulate the abstention fraction, its square and the bin count
        abs_frac = tAbsvotes[k]/N
        mean_Abr[idx] += abs_frac
        sqrmean_Abr[idx] += abs_frac**2
        count_r[idx] += 1
    
#%%
### DATA ANALYSIS

# Symmetrise the data: the projected models depend on the alignment only through
# z**2, so bin i is pooled with its mirror bin (same |ratio|). Reversing with
# [::-1] assumes the `ratios` grid is symmetric about 0.
# NOTE(review): the centre bin (ratio 0) is its own mirror, so it is pooled with
# itself. Its mean is unchanged, but its count is doubled, which shrinks its
# merror by ~sqrt(2). Every other bin enters the fit twice (once per mirror), so
# this may be intended as compensating weighting — confirm.
count_sym = count_r + count_r[::-1]

# Empty bins give 0/0 and single-point bins divide by zero below. The resulting
# NaN/inf entries are removed by the `valid` mask, so those warnings are silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    mean_Abr2 = (mean_Abr+mean_Abr[::-1])
    # Average fraction of abstentions per bin (raw and symmetrised)
    mean_Abr = mean_Abr/count_r
    mean_Abr2 = mean_Abr2/count_sym

    # Sample standard deviation sqrt((sum x^2 - n*mean^2)/(n-1)), raw and symmetrised
    sqr_Abr = np.sqrt((sqrmean_Abr-count_r*mean_Abr**2)/(count_r-1))
    sqr_Abr2 = np.sqrt(((sqrmean_Abr+sqrmean_Abr[::-1])-count_sym*mean_Abr2**2)/(count_sym-1))

    # Standard error of the mean (used to define 95% confidence intervals)
    merror = sqr_Abr2/np.sqrt(count_sym)

# Keep only bins with at least 2 pooled data points; fewer don't allow an
# average and an error to be estimated.
valid = count_sym > 1
mAbr = mean_Abr2[valid]
sqrAbr = sqr_Abr2[valid]
mEr = merror[valid]
rats = ratios[valid]

#%% Data fitting
from scipy.optimize import curve_fit

# Straight line in Euclidean coordinates
def func_str(x,a):
    """Straight line 1 - x - y = a, solved for y.

    x: X-vote fraction(s); a: constant offset alpha.
    Returns y = 1 - a - x.
    """
    intercept = 1 - a
    return intercept - x

# Hyperbola in Euclidean coordinates
def func_curv(x,a):
    """Hyperbola 1 - x - y - a*x*y = 0, solved for y = (1 - x)/(1 + a*x).

    x: X-vote fraction(s); a: curvature parameter gamma.
    """
    numerator = 1 - x
    denominator = 1 + a*x
    return numerator/denominator

# Curve 1-x-y - a*xy = b (a mix between the straight line and the hyperbola)
def func_curvb(x,a,b):
    """Curve 1 - x - y - a*x*y = b, solved for y = (1 - x - b)/(1 + a*x).

    x: X-vote fraction(s); a: curvature gamma; b: offset beta.
    """
    numerator = (1 - x) - b
    return numerator/(1 + a*x)

# Straight line in projected coordinates
def funline(z,a):
    """Straight line 1 - x - y = a in projected coordinates.

    z is the alignment (x - y)/(x + y). Along the line x + y = 1 - a, so the
    projected coordinate (x + y)*sqrt((1 + z**2)/2) equals
    0.5*(1 - a)*sqrt(2*(1 + z**2)).
    """
    half_total = 0.5*(1 - a)
    return half_total*np.sqrt(2*(1 + z**2))

# hyperbola in projected coordinates
def funcurv(z,a):
    """Hyperbola 1 - x - y - a*x*y = 0 in projected coordinates.

    z is the alignment (x - y)/(x + y). The return value is the projected
    coordinate (x + y)*sqrt((1 + z**2)/2), the same quantity used for ydata.

    With P = 2/(1+z**2) and Q = 2*a*(1-z**2)/(1+z**2), the positive root was
    written as (sqrt(P + Q) - sqrt(P)) / (Q/2). Multiplying by the conjugate
    gives the algebraically identical 2 / (sqrt(P + Q) + sqrt(P)), which:
      * has no 0/0 at a = 0 (it reduces to the straight line funline(z, 0));
      * has no 0/0 at z = +-1 (value 1: the point (1, 0) or (0, 1));
      * avoids cancellation between two nearly equal square roots for small Q.
    The result is NaN where P + Q < 0, as with the original form.
    """
    p = 2/(1 + z**2)
    q = 2*a*(1 - z**2)/(1 + z**2)
    # denominator >= sqrt(p) > 0, so this never divides by zero
    return 2/(np.sqrt(p + q) + np.sqrt(p))

# 1-x-y - a*xy = b in projected coordinates
def funcurvb(z,a,b):
    """Curve 1 - x - y - a*x*y = b in projected coordinates.

    z is the alignment (x - y)/(x + y). The return value is the projected
    coordinate (x + y)*sqrt((1 + z**2)/2).

    With P = 2/(1+z**2) and Q = 2*a*(1-b)*(1-z**2)/(1+z**2), the root was
    written as (sqrt(P + Q) - sqrt(P)) / (a*(1-z**2)/(1+z**2)). Rationalising
    gives the identical 2*(1-b) / (sqrt(P + Q) + sqrt(P)), which:
      * has no 0/0 at a = 0 (it reduces to the straight line funline(z, b));
      * has no 0/0 at z = +-1 (value 1 - b);
      * avoids cancellation between two nearly equal square roots.
    The result is NaN where P + Q < 0, as with the original form.
    """
    p = 2/(1 + z**2)
    q = 2*a*(1 - b)*(1 - z**2)/(1 + z**2)
    # denominator >= sqrt(p) > 0, so this never divides by zero
    return 2*(1 - b)/(np.sqrt(p + q) + np.sqrt(p))

# Copy rather than alias: the endpoint nudges below must not leak into `rats`,
# which is reused for ydata, dy and every later plot.
xdata = rats.copy() # alignment values
# keep |z| < 1 so the a*(1-z**2) denominator of the closed-form curves is non-zero
xdata[0] = xdata[0]+0.0000001
xdata[-1] = xdata[-1]-0.0000001
# projected coordinate of the data: (fraction of A+B votes)*sqrt((1+z**2)/2)
ydata = np.sqrt(0.5*(1-mAbr)**2*(1+rats**2))
# error propagation from the abstention fraction: dy = mEr*sqrt((1+z**2)/2)
dy = mEr*(1-mAbr)*(1+rats**2)/(2*ydata)

# Weighted least-squares fits in the projected variables, so the distance is
# measured within each alignment bin rather than in Euclidean x-y coordinates.
line_fit = curve_fit(funline, xdata, ydata, sigma=dy, full_output=True)
curved_fit = curve_fit(funcurv, xdata, ydata, sigma=dy, full_output=True)
curved_fitb = curve_fit(funcurvb, xdata, ydata, sigma=dy, full_output=True)

# Unpack the three fits: index 0 holds the best-fit parameters, index 1 the
# parameter covariance matrix, and index 2 the info dict whose 'fvec' entry
# holds the normalised residuals.
popt, pcovt = line_fit[:2]
popc, pcovc = curved_fit[:2]
popcb, pcovb = curved_fitb[:2]
res_line, res_curv, res_curvb = (fit[2]['fvec'] for fit in (line_fit, curved_fit, curved_fitb))

# Scatter the normalised residuals against alignment to look for obvious patterns
plt.figure(figsize=(8,8))
_residual_series = (
    (res_line, dict(color='k', label='straight-line')),
    (res_curv, dict(marker='s', color='royalblue', label='hyperbola')),
    (res_curvb, dict(marker='^', color='darkorange', label='2-parameter curve')),
)
for _res, _style in _residual_series:
    plt.scatter(rats, _res, **_style)
plt.plot(rats,rats*0,'k--')
plt.legend()
plt.title('Normalised residuals')
plt.show()

# The straight-line manifold shows an obvious "curved" pattern in its residuals,
# so we may discard it as a candidate.
# The fit of 1-x-y - a*xy = b shows a similar pattern to the straight line, but less pronounced.
# The hyperbola looks fine, although there are 4 extreme outliers (2 up to symmetry).
# We can therefore remove outliers to improve the fit of the data.
# Note: the deletions below actually drop three symmetric pairs of points.

#%%
# Removing outliers
# Two successive symmetric deletions; the second pass indexes the already
# shortened arrays. Net effect (for arrays of at least 10 points): drop the
# 1st, 3rd and 5th points counted from each end of the original arrays.
for _outlier_idx in ([0,-1,2,-3], [2,-3]):
    xdata, ydata, dy = [np.delete(arr, _outlier_idx) for arr in (xdata, ydata, dy)]

# Fit the three models again with the outliers removed
line_fit, curved_fit, curved_fitb = (
    curve_fit(model, xdata, ydata, sigma=dy, full_output=True)
    for model in (funline, funcurv, funcurvb)
)
_fits = (line_fit, curved_fit, curved_fitb)

# Best-fit parameters. Values recorded by the author for this data:
#   straight line          0.14442357
#   hyperbola              1.02078634
#   2-parameter curve      a = 0.88525392, b = 0.01962893
popt, popc, popcb = (fit[0] for fit in _fits)

# Parameter covariance matrices. Recorded square roots (presumably of the
# diagonals, i.e. standard errors):
#   straight line 0.00468325; hyperbola 0.0246279;
#   2-parameter curve 0.10073551, 0.01421673
pcovt, pcovc, pcovb = (fit[1] for fit in _fits)

# Normalised residuals (curve_fit info dict, 'fvec'). The author recorded "sum"
# values of 441.59 (straight line), 136.34 (hyperbola) and 131.34 (2-parameter
# curve) — presumably sums of squared residuals; confirm.
res_line, res_curv, res_curvb = (fit[2]['fvec'] for fit in _fits)

# Plot the residuals of the refit against the retained alignment values
plt.figure(figsize=(8,8))
for _res, _style in (
    (res_line, dict(color='k', label='straight-line')),
    (res_curv, dict(marker='s', color='royalblue', label='hyperbola')),
    (res_curvb, dict(marker='^', color='darkorange', label='2-parameter curve')),
):
    plt.scatter(xdata, _res, **_style)
plt.plot(rats,rats*0,'k--')
plt.legend()
plt.title('Normalised residuals')
plt.show()

# Again, the residuals of the straight-line fit have a curved pattern.
# However, the residuals of the two curved fits are fine.

#%% 
# Plotting data and residuals together
# Apply the same two-pass deletion used for xdata/ydata/dy to the plotting
# arrays, so that rats, mAbr and mEr stay aligned with the refitted residuals.
for _outlier_idx in ([0,-1,2,-3], [2,-3]):
    rats, mAbr, mEr = [np.delete(arr, _outlier_idx) for arr in (rats, mAbr, mEr)]

#%%
### Plotting ###
zscore = 1.96  # normal quantile for the two-sided 95% confidence band

import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 20})
plt.rcParams['axes.linewidth'] = 3

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# X-vote coordinate of the bin grid: alignment r in [-1, 1] maps to
# x = (r+1)/2 on the no-abstention diagonal x + y = 1.
xgrid = (ratios+1)/2

# Left panel: binned data mapped back to (X votes, Y votes) via
# x = (1+r)/2*(1-abstention), y = (1-r)/2*(1-abstention), with the 95% band
# from mean abstention +/- zscore*mEr, the diagonal and the three fitted curves.
tot = 1-mAbr
tot_lo = 1-(mAbr+zscore*mEr)
tot_hi = 1-(mAbr-zscore*mEr)
axes[0].scatter(0.5*(1+rats)*tot,0.5*(1-rats)*tot,color='k')
axes[0].plot(0.5*(1+rats)*tot_lo,0.5*(1-rats)*tot_lo,color='k')
axes[0].plot(0.5*(1+rats)*tot_hi,0.5*(1-rats)*tot_hi,color='k')
axes[0].plot(xgrid,1-xgrid,'k--')  # no-abstention diagonal (drawn once)
axes[0].plot(xgrid,1-xgrid-popt[0], color='red',label=r'1-x-y = $\alpha$')
axes[0].plot(xgrid, func_curv(xgrid,popc[0]),color='royalblue',label=r'1-x-y-$\gamma$xy = 0')
# same colour as this model's residual markers in the right panel
axes[0].plot(xgrid, func_curvb(xgrid,popcb[0],popcb[1]),color='darkorange',label=r'1-x-y-$\gamma$xy = $\beta$')
axes[0].axis([0,1,0,1])
axes[0].set(xlabel='X votes',ylabel='Y votes')
axes[0].legend()

# Right panel: normalised residuals of the refits against projected alignment
axes[1].scatter(xdata,res_line,color='red',label=r'1-x-y = $\alpha$')
axes[1].scatter(xdata,res_curv,marker='s',color='royalblue',label=r'1-x-y-$\gamma$xy = 0')
axes[1].scatter(xdata,res_curvb,marker='^',color='darkorange',label=r'1-x-y-$\gamma$xy = $\beta$')
axes[1].plot(rats,rats*0,'k--')
axes[1].set(xlabel='Projected alignment',ylabel='Normalised residual')
axes[1].legend(loc='lower center',fontsize=18)

plt.show()